import numpy as np
import torch
import random
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import re
import torch.nn as nn
import copy
import warnings
import copy

a = torch.tensor([
    [0.6, .3, 0.2, 0.3, 0.1],
    [0.6, .7, 0.2, 0.3, 0.1],
    [0.1, .7, 0.2, 0.3, 0.1],
    [0.6, .3, 0.2, 0.3, 0.1],
    [0.6, .3, 0.2, 0.3, 0.1],
    [0.6, .3, 0.2, 0.3, 0.1],
    [0.5, .3, 0.2, 0.3, 0.1],  # 也有可能那个都没有
])
b = a >=.7
print(b)
print(b.any(1))  # 从输出结果来看1是按列查找，找到任何一个符合的列
print(b.any(0))  # 0是按行查找找到，只找到一行即可